library(prodlim)
library(ape)
library(ggrepel)
library(gmodels)
library(igraph)
library(gridExtra)
library(grid)
library(tidyverse)


# Values of path.classification that count towards anc.matrix.ancestry and
# anc.matrix.probable.pair respectively.
allowed.ancestry <- c("anc", "multiAnc", "desc", "multiDesc", "complex")
allowed.probable.pair <- c("multiAnc", "multiDesc", "complex")

# Find the patient pairs that are linked in one classification table.
#
# classification: table with columns orig.host.1, orig.host.2, adjacent
#                 (logical) and path.classification.
# patients:       character vector of patient identifiers; its order defines
#                 the indices that are returned.
# allowed:        values of path.classification that count as a link.
#
# Returns a two-column index matrix (larger patient index, smaller patient
# index) with one row per distinct pair of different patients that has at
# least one adjacent row with an allowed classification, in either host order.
# Rows whose hosts are missing or not in `patients` are ignored.
linked.pair.indices <- function(classification, patients, allowed) {
  keep <- which(classification$adjacent & classification$path.classification %in% allowed)
  # incomparables = NA: a missing host never matches, even if `patients` holds NA
  ind.1 <- match(classification$orig.host.1[keep], patients, incomparables = NA)
  ind.2 <- match(classification$orig.host.2[keep], patients, incomparables = NA)
  # which() drops unmatched hosts (NA) together with same-patient rows
  usable <- which(ind.1 != ind.2)
  unique(cbind(pmax(ind.1[usable], ind.2[usable]),
               pmin(ind.1[usable], ind.2[usable])))
}

# The patient list and both count matrices are created from the first file.
patients <- NULL

for (i in 1:100) {
  classification.file <- paste0("classifications/Classification_Posterior_ds_10percol_", i, ".csv")
  cat("Opening file ", classification.file, "\n", sep = "")
  classification <- read_csv(classification.file)
  classification$orig.host.1 <- substr(classification$host.1, 1, 4)
  classification$orig.host.2 <- substr(classification$host.2, 1, 4)

  if (is.null(patients)) {
    patients <- unique(c(classification$orig.host.1, classification$orig.host.2))
    patients <- patients[order(patients)]

    # Only the strict lower triangle (row index > column index) holds counts;
    # the diagonal and the upper triangle stay NA.
    anc.matrix.ancestry <- matrix(NA_real_, length(patients), length(patients))
    anc.matrix.probable.pair <- matrix(NA_real_, length(patients), length(patients))
    anc.matrix.ancestry[lower.tri(anc.matrix.ancestry)] <- 0
    anc.matrix.probable.pair[lower.tri(anc.matrix.probable.pair)] <- 0
  }

  # Each linked pair is counted at most once per file.
  ancestry.pairs <- linked.pair.indices(classification, patients, allowed.ancestry)
  anc.matrix.ancestry[ancestry.pairs] <- anc.matrix.ancestry[ancestry.pairs] + 1

  probable.pairs <- linked.pair.indices(classification, patients, allowed.probable.pair)
  anc.matrix.probable.pair[probable.pairs] <- anc.matrix.probable.pair[probable.pairs] + 1
}

rownames(anc.matrix.ancestry) <- colnames(anc.matrix.ancestry) <- patients
rownames(anc.matrix.probable.pair) <- colnames(anc.matrix.probable.pair) <- patients

# Identifiers read from the id column of trace_colonisation_patients.csv;
# used below to set the `trace` column.
goodbye <- read_csv("trace_colonisation_patients.csv") %>% pull(id)

# Sequence labels removed from the alignment. The same list defines the
# prob.contaminant column below, so it is kept in one place.
excluded.labels <- c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02")

# Keep only sequences whose label starts with "T" and is not excluded.
snp.alignment <- read.dna("alignment_nogubbins.fas", format = "fasta")
keep.sequence <- startsWith(rownames(snp.alignment), "T") &
  !(rownames(snp.alignment) %in% excluded.labels)
snp.alignment <- snp.alignment[which(keep.sequence), ]

# Lookup table from tip.name to new.label.
renamer <- read_csv("new_tip_names.csv")

# One row per retained sequence. Labels are split on "_" into at most three
# parts: the first part gives patient/colony, and characters 2-3 of the second
# part give the event number.
tibble.for.samples <- tibble(old.label = rownames(snp.alignment)) %>%
  mutate(
    # always FALSE here, because the excluded labels were dropped above
    prob.contaminant = old.label %in% excluded.labels,
    # renamed label where the lookup has one, otherwise the original label
    new.label = coalesce(
      as.character(renamer$new.label[match(old.label, renamer$tip.name)]),
      old.label
    ),
    colony = str_split_fixed(new.label, "_", 3)[, 1],
    patient = str_split_fixed(old.label, "_", 3)[, 1],
    trace = colony %in% goodbye
  ) %>%
  group_by(patient) %>%
  mutate(all.trace = all(trace)) %>%
  ungroup() %>%
  mutate(
    event.number = as.integer(str_sub(str_split_fixed(old.label, "_", 3)[, 2], 2, 3))
  ) %>%
  group_by(patient) %>%
  mutate(first.event = event.number == min(event.number)) %>%
  ungroup()


# for the traceless version

# firm.alignment.samples <- tibble.for.samples %>% filter(!prob.contaminant & !all.trace) %>% pull(old.label)
# 
# firm.snp.alignment <- snp.alignment[firm.alignment.samples,]

# Labels of the rows flagged first.event: once for all such rows, and once
# leaving out the rows flagged all.trace.
earliest.samples <- tibble.for.samples %>%
  filter(!prob.contaminant, first.event) %>%
  pull(old.label)
earliest.nontrace.samples <- tibble.for.samples %>%
  filter(!prob.contaminant, first.event, !all.trace) %>%
  pull(old.label)

# firm.snp.dists <- dist.dna(firm.snp.alignment, model="N", pairwise.deletion = T, as.matrix = T)

# Sub-alignments for the two label sets.
earliest.sample.alignment <- snp.alignment[earliest.samples, ]
earliest.nontrace.alignment <- snp.alignment[earliest.nontrace.samples, ]

# Pairwise distance matrices (model "N", pairwise deletion) for each sub-alignment.
earliest.snp.dists <- dist.dna(
  earliest.sample.alignment,
  model = "N",
  pairwise.deletion = TRUE,
  as.matrix = TRUE
)
earliest.snp.notrace.dists <- dist.dna(
  earliest.nontrace.alignment,
  model = "N",
  pairwise.deletion = TRUE,
  as.matrix = TRUE
)

# Sequence labels of the earliest-sample alignment and the patient prefix
# (first four characters) of each label. Both are constant across the loop
# below, so they are computed once.
earliest.labels <- labels(earliest.sample.alignment)
earliest.label.patients <- substr(earliest.labels, 1, 4)

# For every sequence, the number of sequences in this alignment that share its
# patient prefix.
earliest.label.seq.counts <- vapply(
  earliest.label.patients,
  function(p) sum(earliest.label.patients == p, na.rm = TRUE),
  integer(1),
  USE.NAMES = FALSE
)

# Sequence-by-sequence matrices, filled only above the diagonal and only for
# pairs of sequences from different patients:
#   earliest.ancestry.pp.mat - entry of anc.matrix.ancestry for the two patients
#   earliest.weights.mat     - 1 / (sequence count of patient 1 * sequence count of patient 2)
earliest.ancestry.pp.mat <- matrix(NA, length(earliest.labels), length(earliest.labels))
earliest.weights.mat <- matrix(NA, length(earliest.labels), length(earliest.labels))
rownames(earliest.ancestry.pp.mat) <- colnames(earliest.ancestry.pp.mat) <- earliest.labels
rownames(earliest.weights.mat) <- colnames(earliest.weights.mat) <- earliest.labels

total <- (length(earliest.labels) * (length(earliest.labels) - 1)) / 2

counter <- 1

for (seq.1.no in seq_along(earliest.labels)) {
  # Visit every unordered pair once: seq.2.no runs over the later sequences only.
  for (seq.2.no in seq_along(earliest.labels)[-seq_len(seq.1.no)]) {
    if (counter %% 100 == 0) {
      cat("Done ", counter, " of ", total, "\n", sep = "")
    }
    seq.1 <- earliest.labels[seq.1.no]
    seq.2 <- earliest.labels[seq.2.no]

    if (seq.1 != "NA" && seq.2 != "NA") {
      pat.1 <- earliest.label.patients[seq.1.no]
      pat.2 <- earliest.label.patients[seq.2.no]

      if (pat.1 != pat.2) {
        # anc.matrix.ancestry is looked up in both orientations; the first
        # non-NA entry is used, and the cell stays NA when both are NA.
        pair.count <- anc.matrix.ancestry[pat.1, pat.2]
        if (is.na(pair.count)) {
          pair.count <- anc.matrix.ancestry[pat.2, pat.1]
        }
        earliest.ancestry.pp.mat[seq.1.no, seq.2.no] <- pair.count

        earliest.weights.mat[seq.1.no, seq.2.no] <-
          1 / (earliest.label.seq.counts[seq.1.no] * earliest.label.seq.counts[seq.2.no])
      }
      counter <- counter + 1
    }
  }
}

# Sequence labels of the non-trace earliest-sample alignment and the patient
# prefix (first four characters) of each label. Both are constant across the
# loop below, so they are computed once.
notrace.labels <- labels(earliest.nontrace.alignment)
notrace.label.patients <- substr(notrace.labels, 1, 4)

# For every sequence, the number of sequences in this alignment that share its
# patient prefix.
notrace.label.seq.counts <- vapply(
  notrace.label.patients,
  function(p) sum(notrace.label.patients == p, na.rm = TRUE),
  integer(1),
  USE.NAMES = FALSE
)

# Sequence-by-sequence matrices, filled only above the diagonal and only for
# pairs of sequences from different patients:
#   earliest.notrace.probpair.pp.mat - entry of anc.matrix.probable.pair
#   earliest.notrace.ancestry.pp.mat - entry of anc.matrix.ancestry
#   earliest.notrace.weights.mat     - 1 / (product of the two patients' sequence counts)
earliest.notrace.probpair.pp.mat <- matrix(NA, length(notrace.labels), length(notrace.labels))
earliest.notrace.ancestry.pp.mat <- matrix(NA, length(notrace.labels), length(notrace.labels))
earliest.notrace.weights.mat <- matrix(NA, length(notrace.labels), length(notrace.labels))
rownames(earliest.notrace.probpair.pp.mat) <- colnames(earliest.notrace.probpair.pp.mat) <- notrace.labels
rownames(earliest.notrace.ancestry.pp.mat) <- colnames(earliest.notrace.ancestry.pp.mat) <- notrace.labels
rownames(earliest.notrace.weights.mat) <- colnames(earliest.notrace.weights.mat) <- notrace.labels

total <- (length(notrace.labels) * (length(notrace.labels) - 1)) / 2

counter <- 1

for (seq.1.no in seq_along(notrace.labels)) {
  # Visit every unordered pair once: seq.2.no runs over the later sequences only.
  for (seq.2.no in seq_along(notrace.labels)[-seq_len(seq.1.no)]) {
    if (counter %% 100 == 0) {
      cat("Done ", counter, " of ", total, "\n", sep = "")
    }
    seq.1 <- notrace.labels[seq.1.no]
    seq.2 <- notrace.labels[seq.2.no]

    if (seq.1 != "NA" && seq.2 != "NA") {
      pat.1 <- notrace.label.patients[seq.1.no]
      pat.2 <- notrace.label.patients[seq.2.no]

      if (pat.1 != pat.2) {
        # Each source matrix is looked up in both orientations; the first
        # non-NA entry is used, and the cell stays NA when both are NA.
        probpair.count <- anc.matrix.probable.pair[pat.1, pat.2]
        if (is.na(probpair.count)) {
          probpair.count <- anc.matrix.probable.pair[pat.2, pat.1]
        }
        earliest.notrace.probpair.pp.mat[seq.1.no, seq.2.no] <- probpair.count

        # The orientation test uses anc.matrix.ancestry itself, the matrix the
        # value is read from, rather than anc.matrix.probable.pair.
        ancestry.count <- anc.matrix.ancestry[pat.1, pat.2]
        if (is.na(ancestry.count)) {
          ancestry.count <- anc.matrix.ancestry[pat.2, pat.1]
        }
        earliest.notrace.ancestry.pp.mat[seq.1.no, seq.2.no] <- ancestry.count

        earliest.notrace.weights.mat[seq.1.no, seq.2.no] <-
          1 / (notrace.label.seq.counts[seq.1.no] * notrace.label.seq.counts[seq.2.no])
      }
      counter <- counter + 1
    }
  }
}




# Flatten the pairwise matrices into plain vectors (column-major order),
# keeping only the cells where the relevant posterior-count matrix is not NA.
ancestry.pp.present <- !is.na(earliest.ancestry.pp.mat)
notrace.ancestry.pp.present <- !is.na(earliest.notrace.ancestry.pp.mat)
notrace.probpair.pp.present <- !is.na(earliest.notrace.probpair.pp.mat)

# Distances and weights are selected with the ancestry mask of their own alignment.
earliest.snp.dists.long <- earliest.snp.dists[ancestry.pp.present]
earliest.snp.notrace.dists.long <- earliest.snp.notrace.dists[notrace.ancestry.pp.present]

# Each posterior-count matrix is selected with its own mask.
earliest.ancestry.pp.mat.long <- earliest.ancestry.pp.mat[ancestry.pp.present]
earliest.notrace.probpair.pp.mat.long <- earliest.notrace.probpair.pp.mat[notrace.probpair.pp.present]
earliest.notrace.ancestry.pp.mat.long <- earliest.notrace.ancestry.pp.mat[notrace.ancestry.pp.present]

earliest.weights.mat.long <- earliest.weights.mat[ancestry.pp.present]
earliest.notrace.weights.mat.long <- earliest.notrace.weights.mat[notrace.ancestry.pp.present]

# One row per retained sequence pair. classified.as is 1 when the ancestry
# count is at least 95, -1 when it is at most 5; all other pairs are dropped.
earliest.sample.comparison.df <- tibble(
  snp.distance = earliest.snp.dists.long,
  post.prob.ancestry = earliest.ancestry.pp.mat.long,
  weight = earliest.weights.mat.long
) %>%
  mutate(
    classified.as = case_when(
      post.prob.ancestry >= 95 ~ 1,
      post.prob.ancestry <= 5 ~ -1,
      TRUE ~ NA_real_
    )
  ) %>%
  filter(!is.na(classified.as))

# As above for the non-trace alignment, except that the positive class is
# decided by the probable-pair count (>= 95) and the negative class by the
# ancestry count (<= 5). The positive test takes precedence.
earliest.notrace.sample.comparison.df <- tibble(
  snp.distance = earliest.snp.notrace.dists.long,
  post.prob.ancestry = earliest.notrace.ancestry.pp.mat.long,
  post.prob.probpair = earliest.notrace.probpair.pp.mat.long,
  weight = earliest.notrace.weights.mat.long
) %>%
  mutate(
    classified.as = case_when(
      post.prob.probpair >= 95 ~ 1,
      post.prob.ancestry <= 5 ~ -1,
      TRUE ~ NA_real_
    )
  ) %>%
  filter(!is.na(classified.as))



# Cut-offs on the posterior counts, as proportions. The counts are compared
# against 100 times these values below.
p.cutoff.upper <- 0.95
p.cutoff.lower <- 0.05

# Weighted sensitivity and specificity of the rule "distance <= d.thresh".
#
# d.thresh:    distance threshold being evaluated.
# is.positive: logical vector marking the pairs treated as true positives.
# is.negative: logical vector marking the pairs treated as true negatives.
# snp.dists:   distance of every pair.
# weights:     weight of every pair.
#
# Returns a list with dist.threshold, sensitivity and specificity.
sens.spec.at.threshold <- function(d.thresh, is.positive, is.negative, snp.dists, weights) {
  within.threshold <- snp.dists <= d.thresh
  beyond.threshold <- snp.dists > d.thresh

  tp.f <- sum(weights[which(is.positive & within.threshold)])
  fn.f <- sum(weights[which(is.positive & beyond.threshold)])
  tn.f <- sum(weights[which(is.negative & beyond.threshold)])
  fp.f <- sum(weights[which(is.negative & within.threshold)])

  list(dist.threshold = d.thresh,
       sensitivity = tp.f / (tp.f + fn.f),
       specificity = tn.f / (tn.f + fp.f))
}

# Positives and negatives are both defined from the ancestry counts.
earliest.stats <- map(0:1000, function(d.thresh) {
  sens.spec.at.threshold(
    d.thresh,
    is.positive = earliest.ancestry.pp.mat.long >= 100 * p.cutoff.upper,
    is.negative = earliest.ancestry.pp.mat.long <= 100 * p.cutoff.lower,
    snp.dists = earliest.snp.dists.long,
    weights = earliest.weights.mat.long
  )
}) %>%
  bind_rows() %>%
  mutate(version = "<= 1 lineage transmitted")

# Positives come from the probable-pair counts, negatives from the ancestry counts.
earliest.notrace.stats <- map(0:1000, function(d.thresh) {
  sens.spec.at.threshold(
    d.thresh,
    is.positive = earliest.notrace.probpair.pp.mat.long >= 100 * p.cutoff.upper,
    is.negative = earliest.notrace.ancestry.pp.mat.long <= 100 * p.cutoff.lower,
    snp.dists = earliest.snp.notrace.dists.long,
    weights = earliest.notrace.weights.mat.long
  )
}) %>%
  bind_rows() %>%
  mutate(version = "<= 2 lineages transmitted")

all.stats <- bind_rows(earliest.stats, earliest.notrace.stats)

# Long format: one row per (threshold, version, statistic). The two statistic
# columns are selected by name rather than by position.
all.stats.gathered <- gather(all.stats, statistic, value, sensitivity, specificity)

# Sensitivity and specificity against the SNP threshold, one facet per version.
# Both figures share this specification and differ only in the x-axis range.
ss.plot.base <- ggplot(data = all.stats.gathered) +
  geom_line(aes(x = dist.threshold, y = value, col = statistic), size = 1, lineend = "round", alpha = 0.8) +
  theme_bw() +
  facet_wrap(~version) +
  scale_color_brewer(palette = "Set1", name = "Statistic", labels = c("Sensitivity", "Specificity")) +
  xlab("SNP threshold") +
  ylab("Statistic value")

# The plot is passed to ggsave() explicitly so the saved figure does not depend
# on whichever plot happens to be the session's last one.
ss.plot <- ss.plot.base + xlim(c(0, 60))
ggsave("Figure9A.pdf", plot = ss.plot, width = 9, height = 4)

ss.plot <- ss.plot.base + xlim(c(0, 500))
ggsave("Figure9S1.pdf", plot = ss.plot, width = 9, height = 4)


# ROC coordinates for every threshold.
all.stats$tpr <- all.stats$sensitivity
all.stats$fpr <- 1 - all.stats$specificity

# Euclidean distance from each (fpr, tpr) point to the corner (0, 1).
all.stats$pythagoras <- sqrt((1 - all.stats$tpr)^2 + all.stats$fpr^2)

# Thresholds to annotate. all.stats holds one row per threshold for every
# version, so for.annot has one row per (threshold, version).
annot.thresholds <- c(0, 10, 25, 50, 100, 150, 200, 300, 400, 1000)
for.annot <- all.stats[which(all.stats$dist.threshold %in% annot.thresholds), ]

# rowno is the position of the row's threshold in annot.thresholds (1 to 10).
# It is derived from the threshold itself, so it is correct for any number of
# rows and does not rely on recycling a length-10 vector.
for.annot$rowno <- match(for.annot$dist.threshold, annot.thresholds)



# Thresholds shown in the fill legend. This assumes for.annot$rowno == k marks
# the k-th threshold in this list (TODO confirm against where for.annot is built).
roc.legend.thresholds <- c(0, 10, 25, 50, 100, 150, 200, 300, 400, 1000)

# ROC curve (tpr against fpr) per version, with the annotated thresholds drawn
# as filled points on top of the path.
roc.plot <- ggplot(data = all.stats) +
  geom_path(aes(x = fpr, y = tpr), size = 1, lineend = "round") +
  geom_point(data = for.annot, aes(x = fpr, y = tpr, fill = rowno), size = 2, shape = 21) +
  #geom_point(data = sens.spec.df.w[which(sens.spec.df.w$threshold %in% c(0, 10, 20, 30, 40, 50, 100, 200, 300, 400)),], aes(x=fpr, y=tpr), col="red", size=0.75) +
  scale_fill_gradientn(
    guide = "legend",
    breaks = seq_along(roc.legend.thresholds),
    colours = c("red", "white", "green"),
    labels = roc.legend.thresholds,
    name = "SNP threshold"
  ) +
  theme_bw() +
  facet_wrap(~version) +
  xlab("1 - Specificity") +
  ylab("Sensitivity") +
  scale_x_continuous(limits = c(0, 1)) +
  scale_y_continuous(limits = c(0, 1))

roc.plot

# The plot is passed explicitly so the saved figure does not depend on
# whichever plot happens to be the session's last one.
ggsave("Figure9B.pdf", plot = roc.plot, width = 9, height = 4)

